/*
 * Adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_common.h
 * Copyright (c) 2023, Tri Dao.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Not a contribution
 * Changes made by NVIDIA CORPORATION & AFFILIATES or otherwise documented as
 * NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cstdint>
#include <type_traits>

namespace tensorrt_llm
{
namespace kernels
{

// Upper bound on the SSM state dimension (d_state) supported by these kernels.
// NOTE(review): presumably used by callers to size fixed-length per-thread or
// shared-memory buffers — confirm against the kernels that include this header.
#define MAX_DSTATE 256

// Component-wise sum of two float2 values.
inline __device__ float2 operator+(const float2& lhs, const float2& rhs)
{
    return make_float2(lhs.x + rhs.x, lhs.y + rhs.y);
}

// Component-wise sum of two float3 values.
inline __device__ float3 operator+(const float3& lhs, const float3& rhs)
{
    return make_float3(lhs.x + rhs.x, lhs.y + rhs.y, lhs.z + rhs.z);
}

// Component-wise sum of two float4 values.
inline __device__ float4 operator+(const float4& lhs, const float4& rhs)
{
    return make_float4(lhs.x + rhs.x, lhs.y + rhs.y, lhs.z + rhs.z, lhs.w + rhs.w);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

/// Compile-time dispatch on a runtime boolean.
///
/// Evaluates COND once at runtime and invokes the trailing lambda with
/// CONST_NAME bound to a `constexpr bool` in each branch, so the callee can
/// use it as a template argument. The whole construct is an immediately
/// invoked lambda, so it forwards the callee's return value.
///
/// @param COND       - a boolean expression to switch by
/// @param CONST_NAME - the name given to the constexpr bool variable
/// @param ...        - a nullary callable run in both branches
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)                                     \
    [&]                                                                        \
    {                                                                          \
        if (COND)                                                              \
        {                                                                      \
            constexpr bool CONST_NAME = true;                                  \
            return __VA_ARGS__();                                              \
        }                                                                      \
        else                                                                   \
        {                                                                      \
            constexpr bool CONST_NAME = false;                                 \
            return __VA_ARGS__();                                              \
        }                                                                      \
    }()

////////////////////////////////////////////////////////////////////////////////////////////////////

// Maps a byte count (1/2/4/8/16) to an unsigned type of exactly that size,
// used to reinterpret element arrays as wide vector loads/stores. Using an
// unsupported size leaves the primary template empty and fails to compile at
// the point of use of `Type`.
template <int BYTES>
struct BytesToType
{
};

template <>
struct BytesToType<1>
{
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

template <>
struct BytesToType<2>
{
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<4>
{
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<8>
{
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<16>
{
    using Type = uint4; // CUDA built-in 16-byte vector type
    static_assert(sizeof(Type) == 16);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Generic fallback: element-wise widening of an N-element array to float.
// Relies on scalar_t being implicitly convertible to float.
template <typename scalar_t, int N>
struct Converter
{
    static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N])
    {
#pragma unroll
        for (int idx = 0; idx < N; ++idx)
        {
            dst[idx] = static_cast<float>(src[idx]);
        }
    }
};

// half -> float specialization: converts two elements per step by viewing the
// arrays as half2 / float2 pairs (hence N must be even).
template <int N>
struct Converter<half, N>
{
    static inline __device__ void to_float(const half (&src)[N], float (&dst)[N])
    {
        static_assert(N % 2 == 0);
        auto& src_pairs = reinterpret_cast<const half2(&)[N / 2]>(src);
        auto& dst_pairs = reinterpret_cast<float2(&)[N / 2]>(dst);
#pragma unroll
        for (int pair = 0; pair < N / 2; ++pair)
        {
            dst_pairs[pair] = __half22float2(src_pairs[pair]);
        }
    }
};

#if __CUDA_ARCH__ >= 800
// bf16 -> float specialization, two elements per step via nv_bfloat162 pairs.
// Compiled only for SM80+ device passes; other passes fall back to the generic
// element-wise Converter above.
template <int N>
struct Converter<__nv_bfloat16, N>
{
    static inline __device__ void to_float(const __nv_bfloat16 (&src)[N], float (&dst)[N])
    {
        static_assert(N % 2 == 0);
        auto& src_pairs = reinterpret_cast<const nv_bfloat162(&)[N / 2]>(src);
        auto& dst_pairs = reinterpret_cast<float2(&)[N / 2]>(dst);
#pragma unroll
        for (int pair = 0; pair < N / 2; ++pair)
        {
            dst_pairs[pair] = __bfloat1622float2(src_pairs[pair]);
        }
    }
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// Associative combine operator for the first-order linear recurrence
// x_t = a_t * x_{t-1} + b_t. Each float2 packs an (a, b) pair; composing
// (a0, b0) followed by (a1, b1) gives (a1 * a0, a1 * b0 + b1), which lets
// the recurrence be evaluated with a parallel scan.
template <typename scalar_t>
struct SSMScanOp;

template <>
struct SSMScanOp<float>
{
    __device__ __forceinline__ float2 operator()(const float2& prev, const float2& next) const
    {
        return {next.x * prev.x, next.x * prev.y + next.y};
    }
};

// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations (the shape of cub::BlockScan's
// prefix-callback protocol). Requires <type_traits> for std::conditional_t
// and std::is_same_v.
template <typename scalar_t>
struct SSMScanPrefixCallbackOp
{
    // float scans carry a packed (a, b) pair; any other scalar_t carries a float4.
    using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
    scan_t running_prefix; // aggregate of all tiles scanned so far

    // Constructor
    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_)
        : running_prefix(running_prefix_)
    {
    }

    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    // Folds this tile's block_aggregate into the running prefix via SSMScanOp and
    // returns the prefix as it was BEFORE this tile — the order of the three
    // statements below is the contract.
    __device__ scan_t operator()(scan_t block_aggregate)
    {
        scan_t old_prefix = running_prefix;
        running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
        return old_prefix;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Loads one tile of input into per-thread registers `u_vals`.
// When the sequence length is known even (kIsEvenLen), the same shared-memory
// storage is reinterpreted for a vectorized block load; otherwise a guarded
// block load reads only `seqlen` valid elements and pads the rest with 0.
template <typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t* u, typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
    typename Ktraits::BlockLoadT::TempStorage& smem_load, int seqlen)
{
    if constexpr (!Ktraits::kIsEvenLen)
    {
        // Guarded path: out-of-range items are filled with 0.
        Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
    }
    else
    {
        using vec_t = typename Ktraits::vec_t;
        auto& vec_storage = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
        Ktraits::BlockLoadVecT(vec_storage)
            .Load(reinterpret_cast<vec_t*>(u), reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals));
    }
}

// Loads one tile of B/C weights and widens them to weight_t (float) values.
// Raw input_t values are staged in registers first — vectorized when the
// length is even, guarded (zero-padded past seqlen) otherwise — and then
// converted in pairs by Converter::to_float.
template <typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t* Bvar,
    typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
    typename Ktraits::BlockLoadWeightT::TempStorage& smem_load_weight, int seqlen)
{
    constexpr int kNItems = Ktraits::kNItems;
    typename Ktraits::input_t staged[kNItems];
    if constexpr (!Ktraits::kIsEvenLen)
    {
        // Guarded path: out-of-range items are filled with 0.
        Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, staged, seqlen, 0.f);
    }
    else
    {
        using vec_t = typename Ktraits::vec_t;
        auto& vec_storage = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
        Ktraits::BlockLoadWeightVecT(vec_storage)
            .Load(reinterpret_cast<vec_t*>(Bvar), reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(staged));
    }
    Converter<typename Ktraits::input_t, kNItems>::to_float(staged, B_vals);
}

// Narrows the float results to input_t and stores one tile back to `out`.
// Even-length sequences take the vectorized block store (reusing the same
// shared-memory storage); otherwise a guarded store writes only the first
// `seqlen` elements.
template <typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t* out, const float (&out_vals)[Ktraits::kNItems],
    typename Ktraits::BlockStoreT::TempStorage& smem_store, int seqlen)
{
    // Convert float -> input_t in registers before the block-wide store.
    typename Ktraits::input_t narrowed[Ktraits::kNItems];
#pragma unroll
    for (int idx = 0; idx < Ktraits::kNItems; ++idx)
    {
        narrowed[idx] = out_vals[idx];
    }
    if constexpr (!Ktraits::kIsEvenLen)
    {
        Ktraits::BlockStoreT(smem_store).Store(out, narrowed, seqlen);
    }
    else
    {
        using vec_t = typename Ktraits::vec_t;
        auto& vec_storage = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
        Ktraits::BlockStoreVecT(vec_storage)
            .Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(narrowed));
    }
}

} // namespace kernels
} // namespace tensorrt_llm
